//Execute.scala
package test

import chisel3._
import chisel3.util._


/** Execute stage: latches the incoming `idex` bundle for one cycle, computes an
  * ALU result selected by individual bits of `aluop` (plus a load/store address),
  * and drives both the `exmem` output bundle and the `RD_ex` forwarding bundle.
  *
  * NOTE(review): the result is formed by OR-ing together gated candidates, so it
  * is only meaningful if at most one select is active at a time — presumably
  * `aluop` is one-hot; confirm against the decoder.
  */
class Execute extends Module{
  val io = IO(new Bundle{
    val idex = Input(new IDEX)           // bundle arriving from the previous stage
    val exmem = Output(new EXMEM)        // bundle handed to the next stage

    val RD_ex = Output(new ReDirection)  // write-enable / destination / data of this stage's result
    val Bubble_ex = Output(Bool())       // high while the latched instruction matches the load pattern
  })

  /** Passes `value` through when `enable` is set, otherwise yields a 32-bit zero. */
  private def gate(enable: Bool, value: UInt): UInt = Mux(enable, value, 0.U(32.W))

  /** ORs all candidates together, left to right. */
  private def orAll(terms: UInt*): UInt = terms.reduce(_ | _)

  // Stage register: everything below works on the one-cycle-delayed input.
  val exs = RegNext(io.idex)

  /** Instruction bits 31..12 shifted left by 12, passed through `SignExt` to 32 bits. */
  private def upperImmediate: UInt = SignExt(exs.Inst(31,12) << 0x0c, 32).asUInt

  // Both patterns fix bits 14..12 to 010; they differ only in the low 7 bits.
  // NOTE(review): these look like the RV32I LW / SW encodings — confirm.
  val load = (exs.Inst === BitPat("b???????_?????_?????_010_?????_0000011"))
  val store = (exs.Inst === BitPat("b???????_?????_?????_010_?????_0100011"))
  val load_offset = exs.simm
  // Store offset is assembled from Inst[31:25] and Inst[11:7].
  val store_offset = SignExt(Cat(exs.Inst(31,25), exs.Inst(11,7)), 32)

  io.Bubble_ex := load

  val alu = new ALU

  // aluop bits 1..6: shifts, alternating between src2[4:0] and shamt as the amount.
  val shifts = orAll(
    gate(exs.aluop(1), alu.sll(exs.src1, exs.src2(4,0))),
    gate(exs.aluop(2), alu.sll(exs.src1, exs.shamt)),
    gate(exs.aluop(3), alu.srl(exs.src1, exs.src2(4,0))),
    gate(exs.aluop(4), alu.srl(exs.src1, exs.shamt)),
    gate(exs.aluop(5), alu.sra(exs.src1, exs.src2(4,0))),
    gate(exs.aluop(6), alu.sra(exs.src1, exs.shamt))
  )

  // aluop bits 8..12: add/sub, the upper immediate itself, and PC + upper immediate.
  val arithmetic = orAll(
    gate(exs.aluop(8), alu.add(exs.src1, exs.src2)),
    gate(exs.aluop(9), alu.add(exs.src1, exs.simm)),
    gate(exs.aluop(10), alu.sub(exs.src1, exs.src2)),
    gate(exs.aluop(11), upperImmediate),
    gate(exs.aluop(12), (exs.PC + upperImmediate))
  )

  // aluop bits 16..21: xor/or/and against src2 or the immediate.
  val logical = orAll(
    gate(exs.aluop(16), alu.xor(exs.src1, exs.src2)),
    gate(exs.aluop(17), alu.xor(exs.src1, exs.simm)),
    gate(exs.aluop(18), alu.or(exs.src1, exs.src2)),
    gate(exs.aluop(19), alu.or(exs.src1, exs.simm)),
    gate(exs.aluop(20), alu.and(exs.src1, exs.src2)),
    gate(exs.aluop(21), alu.and(exs.src1, exs.simm))
  )

  // aluop bits 24..27: slt/sltu against src2 or the immediate; bit 31: mul.
  val compareM = orAll(
    gate(exs.aluop(24), alu.slt(exs.src1, exs.src2)),
    gate(exs.aluop(25), alu.slt(exs.src1, exs.simm)),
    gate(exs.aluop(26), alu.sltu(exs.src1, exs.src2)),
    gate(exs.aluop(27), alu.sltu(exs.src1, exs.simm)),
    gate(exs.aluop(31), alu.mul(exs.src1, exs.src2))
  )

  // Memory address: src1 plus the offset matching the access kind.
  val address = orAll(
    gate(load, alu.add(exs.src1, load_offset)),
    gate(store, alu.add(exs.src1, store_offset))
  )

  val aluresult = orAll(shifts, arithmetic, logical, compareM, address)

  // Output bundle for the next stage: pass-through fields first, then computed ones.
  io.exmem.Valid := exs.Valid
  io.exmem.Inst := exs.Inst
  io.exmem.PC := exs.PC
  io.exmem.wen := exs.wen
  io.exmem.regdes := exs.regdes
  io.exmem.storeData := exs.src2
  io.exmem.load := load
  io.exmem.store := store
  io.exmem.aluresult := aluresult

  // Forwarding bundle mirrors the write-enable, destination and result of this stage.
  io.RD_ex.wen := io.exmem.wen
  io.RD_ex.regdes := exs.regdes
  io.RD_ex.wbdata := aluresult

}
